#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
from .base_scheduler import BaseSchedulerPlugin


class LinearSchedulerPlugin(BaseSchedulerPlugin):
    """线性学习率调度器插件"""
    
    def get_lr(self, current_step, total_steps, base_lr):
        """
        使用线性衰减公式计算当前学习率
        
        Args:
            current_step (int): 当前步数
            total_steps (int): 总步数
            base_lr (float): 基础学习率
            
        Returns:
            float: 当前学习率
        """
        # 线性衰减: 从base_lr到base_lr/10
        return base_lr - (base_lr - base_lr / 10) * (current_step / total_steps)